import dataclasses
import glob
import json
import os
import pickle
import re
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Generic, TypeVar

import numpy as np
import numpy as np
import scipy
import scipy.io
import torch
from himalaya.kernel_ridge import (
    KernelRidgeCV,
)
from himalaya.ridge import RidgeCV
from himalaya.scoring import correlation_score
from nsd_access import NSDAccess
from PIL import Image
from sklearn.decomposition import PCA
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.random_projection import SparseRandomProjection
from statsmodels.stats.multitest import fdrcorrection
from tqdm import tqdm


# Optional CuPy-based GPU detection; `cp` is None when CuPy is not installed.
# ImportError is handled on its own: naming `cp.cuda...` in the same except
# clause would itself raise NameError when the import failed (cp undefined).
try:
    import cupy as cp
except ImportError:
    cp = None

use_gpu = False
if cp is not None:
    try:
        # A GPU is usable only if at least one CUDA device is visible.
        use_gpu = cp.cuda.runtime.getDeviceCount() > 0
    except cp.cuda.runtime.CUDARuntimeError:
        use_gpu = False
# GPU use is currently forced off regardless of the detection above.
use_gpu = False

T = TypeVar("T")
U = TypeVar("U")


@dataclasses.dataclass
class TrnVal(Generic[T]):
    """A pair of values, one for the training split and one for the validation split.

    The payload can be anything split-specific: a fitted model, a path, or data.
    """

    trn: T  # value belonging to the training split
    val: T  # value belonging to the validation split

    def map_fn(self, f: Callable[[T], U]) -> "TrnVal[U]":
        """Return a new TrnVal holding f(trn) and f(val); f is called on trn first."""
        mapped_trn = f(self.trn)
        mapped_val = f(self.val)
        return TrnVal(trn=mapped_trn, val=mapped_val)


def make_filename(reduce_dim, dataset="all"):
    """Build the filename stem that encodes a dimensionality-reduction setting.

    ``reduce_dim[0]`` selects the scheme: "pca" gives "<n>PCs", "srp" gives
    "<eps>eps" (both read ``reduce_dim[1]``), and anything else gives "raw".
    Unless ``dataset`` is "all", "_<dataset>" is appended.
    """
    method = reduce_dim[0]
    if method in ("pca", "srp"):
        suffix = "PCs" if method == "pca" else "eps"
        stem = f"{reduce_dim[1]}{suffix}"
    else:
        stem = "raw"
    return stem if dataset == "all" else f"{stem}_{dataset}"

def check_saved_score(scores_save_dir, reduce_dim, feat_path, dataset):
    """Return True if encoding scores for this setting are already saved.

    The expected file is ``{scores_save_dir}/cc_{make_filename(reduce_dim, dataset)}.npy``.
    It is built with the shared ``make_filename`` helper instead of re-spelling
    the naming scheme per method/dataset. When the file exists, a message naming
    ``feat_path`` is printed.
    """
    score_path = f"{scores_save_dir}/cc_{make_filename(reduce_dim, dataset)}.npy"
    if not os.path.exists(score_path):
        return False

    if reduce_dim[0] == "pca":
        label = "PCs'"
    elif reduce_dim[0] == "srp":
        label = "SRP'"
    else:
        label = "raw's"
    print(f"{feat_path} {label} encoding results are already exist.")
    return True

def resize_image(file_path, new_file_path, width, height):
    """Read the image at ``file_path``, resize it to (width, height) and save it to ``new_file_path``."""
    target_size = (width, height)
    with Image.open(file_path) as source_img:
        source_img.resize(target_size).save(new_file_path)

def resize_images(source_dir, target_dir, width, height):
    """
    Resize all images in a directory to the specified width and height,
    and save them in a new directory structure mirroring the original.

    Does nothing (after printing a message) if ``source_dir`` is missing or if
    ``target_dir`` already contains any file. Files with extensions
    .png/.jpg/.jpeg/.bmp/.gif (case-insensitive) are resized in a thread pool.
    Exceptions raised by any resize are re-raised when collecting results.
    """
    if not os.path.exists(source_dir):
        print(f"Source directory {source_dir} does not exist.")
        return

    # Treat a target tree that already holds any file as "done".
    for _root, _dirs, existing_files in os.walk(target_dir):
        if existing_files:
            print("Target directory already exists.")
            return

    print("Resizing images...")
    os.makedirs(target_dir, exist_ok=True)

    image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')
    tasks = []
    with ThreadPoolExecutor() as executor:
        for root, dirs, files in os.walk(source_dir):
            # Mirror `root` under target_dir via its path relative to source_dir.
            # (str.replace would also rewrite later occurrences of source_dir's
            # text, e.g. "src/src" -> "dst/dst".)
            rel_root = os.path.relpath(root, source_dir)
            target_root = os.path.normpath(os.path.join(target_dir, rel_root))

            # Creating corresponding directories in the target directory
            for dir_name in dirs:
                os.makedirs(os.path.join(target_root, dir_name), exist_ok=True)

            # Resizing images and saving them in the corresponding location
            for file_name in files:
                if file_name.lower().endswith(image_exts):
                    tasks.append(executor.submit(
                        resize_image,
                        os.path.join(root, file_name),
                        os.path.join(target_root, file_name),
                        width,
                        height,
                    ))

        # Wait for all tasks to complete (re-raises worker exceptions)
        for task in tasks:
            task.result()



def fit_pca(stim, num_pcs, return_projector=False, random_state=42):
    """Fit PCA on the training split and project both splits.

    Parameters
    ----------
    stim : TrnVal
        ``stim.trn`` is used for fitting; ``stim.trn`` and ``stim.val`` are both transformed.
    num_pcs : int
        Number of components (``n_components``) to keep.
    return_projector : bool
        If True, also return the fitted ``PCA`` object.
    random_state : int
        Seed passed to ``PCA`` (default 42, so existing callers are unaffected).

    Returns
    -------
    TrnVal, or (TrnVal, PCA) when ``return_projector`` is True.
    """
    projector = PCA(random_state=random_state, n_components=num_pcs).fit(stim.trn)
    reduced = TrnVal(trn=projector.transform(stim.trn), val=projector.transform(stim.val))
    return (reduced, projector) if return_projector else reduced

def fit_srp(stim, epsilon, return_projector=False, random_state=42):
    """
    Fit and transform using Sparse Random Projection
    The number of dimensions after fit is determined by the number of samples at the time of fitting and the value of epsilon
    Example 1: If you fit with 22,200 samples and set epsilon to 0.1, the number of dimensions after transformation will be 8,578.
    Example 2: If you fit with 22,200 samples and set epsilon to 0.2, the number of dimensions after transformation will be 2,309.

    ``random_state`` (default 42) is exposed so callers can control the seed
    centrally (e.g. from a shared config) instead of relying on a hard-coded value.
    Returns a TrnVal, or (TrnVal, projector) when ``return_projector`` is True.
    """
    projector = SparseRandomProjection(random_state=random_state, eps=epsilon).fit(stim.trn)
    reduced = TrnVal(trn=projector.transform(stim.trn), val=projector.transform(stim.val))
    return (reduced, projector) if return_projector else reduced

def search_best_layer(score_dir, filename, select_topN="all"):
    """Find the layer whose mean cross-validation score is highest.

    Each sub-directory of ``score_dir`` is a layer; a layer is considered only if
    it contains both ``cv_cc_{filename}.npy`` and ``cc_{filename}.npy``.

    Parameters
    ----------
    score_dir : str
        Directory containing one sub-directory per layer.
    filename : str
        Score filename stem (see ``make_filename``).
    select_topN : "all" or int
        If an int, the test accuracy / standard error use only the top-N ``cc`` values.

    Returns
    -------
    (best_layer, mean test cc, standard error of that mean)

    Raises
    ------
    ValueError
        If ``score_dir`` does not exist or no layer has usable score files.
    """
    if not os.path.exists(score_dir):
        raise ValueError(f"{score_dir} does not contain scores for each layer.")

    # -inf (not 0) so a layer is still selected when every CV score is <= 0.
    max_accuracy = -np.inf
    max_layer = None
    max_layer_accuracy = None
    max_std_error = 0
    # Sorted so ties are broken deterministically (os.listdir order is arbitrary).
    for layer in sorted(os.listdir(score_dir)):
        layer_dir = os.path.join(score_dir, layer)
        cv_score_path = os.path.join(layer_dir, f'cv_cc_{filename}.npy')
        score_path = os.path.join(layer_dir, f'cc_{filename}.npy')
        if not (os.path.isfile(cv_score_path) and os.path.isfile(score_path)):
            continue
        cv_accuracy = np.load(cv_score_path).mean()
        if cv_accuracy > max_accuracy:
            max_accuracy = cv_accuracy
            max_layer = layer
            cc_scores = np.load(score_path)
            if select_topN != "all":
                cc_scores = np.sort(cc_scores)[::-1][:select_topN]
            max_std_error = cc_scores.std() / np.sqrt(len(cc_scores))
            max_layer_accuracy = cc_scores.mean()

    if max_layer is None:
        raise ValueError(
            f"No layer in {score_dir} has both cv_cc_{filename}.npy and "
            f"cc_{filename}.npy with a valid CV score."
        )
    return max_layer, max_layer_accuracy, max_std_error
    
def load_frames(frames_topdir, dataset_name):
    """Collect frame image paths found directly under ``frames_topdir``.

    For ``dataset_name == "nsd"``: returns ``{"nsd": sorted paths ending in "png"}``.
    Otherwise: every sub-directory is a segment whose name ends in ``_<int>``;
    returns ``{segment_name: sorted *.jpg paths}`` ordered by that integer.
    """
    entries = glob.glob(os.path.join(frames_topdir, "*"))

    if dataset_name == "nsd":
        return {"nsd": sorted(p for p in entries if p.endswith("png"))}

    def segment_number(path):
        return int(os.path.basename(path).split("_")[-1])

    segment_dirs = sorted((p for p in entries if os.path.isdir(p)), key=segment_number)
    return {
        os.path.basename(seg): sorted(glob.glob(os.path.join(seg, "*.jpg")))
        for seg in segment_dirs
    }


def gen_nulldistrib_gauss(nvoxels: int,
                          valnum: int) -> np.ndarray:
    """Null distribution of correlations between independent Gaussian signals.

    Draws two sets of ``side`` random signals of length ``valnum`` from numpy's
    global RNG and returns the first ``nvoxels`` cross-correlations between them.

    ``side`` is 400 (400 x 400 = 160,000 pairs, enough for the usual cortex voxel
    count). It grows only when more than 160,000 values are requested, so earlier
    results are reproduced exactly for smaller ``nvoxels``; previously the output
    was silently truncated to 160,000 entries.
    """
    side = max(400, int(np.ceil(np.sqrt(max(nvoxels, 0)))))
    a = np.random.randn(side, valnum)
    b = np.random.randn(side, valnum)
    # corrcoef stacks a over b; the lower-left block holds corr(b_i, a_j).
    rccs = np.corrcoef(a, b)[side:, :side].ravel()
    return rccs[:nvoxels]

def gen_nulldistrib_block(resp_true: torch.Tensor,
                          resp_pred: np.ndarray,
                          block_size = 10,
                          num_iterations = 1,
                          device = 0) -> list[np.ndarray]:
    """
    Block permutation: c.f. Tang et al., 2023 NeurIPS https://arxiv.org/abs/2305.12248
    Note that the implementation slightly differs from Tang et al. for computational reasons:
    - This function performs pure block permutation without replacement.
    - The default value of num_iterations is set to 1, so not creating a null distribution for each voxel.

    Trials (rows of ``resp_true``) are grouped into consecutive blocks of
    ``block_size`` (the last block holds any remainder). The block order is
    shuffled, and ``correlation_score`` is computed between ``resp_pred`` and the
    shuffled ``resp_true``. The scores of every iteration are concatenated into
    one list.

    Side effect: seeds numpy's global RNG with 42.
    """
    np.random.seed(42)
    n_trials = resp_true.shape[0]
    pred = torch.Tensor(resp_pred).to(device)

    # Consecutive index blocks: [0..bs-1], [bs..2bs-1], ..., [remainder].
    base_blocks = [
        list(range(start, min(start + block_size, n_trials)))
        for start in range(0, n_trials, block_size)
    ]

    null_ccs = []
    for _ in range(num_iterations):
        # Shuffle a fresh copy each time so every iteration starts from sorted blocks.
        blocks = list(base_blocks)
        np.random.shuffle(blocks)
        order = np.concatenate(blocks)
        # Compute linear correlation
        ccs = correlation_score(pred, resp_true[order, :])
        null_ccs.extend(ccs.detach().cpu().numpy())

    return null_ccs

def fdr_correction(ccs: np.ndarray,
                   rccs: np.ndarray,
                   ) -> tuple[np.ndarray, np.ndarray]:
    """FDR (Benjamini-Hochberg, ``method="indep"``) test of correlations against a null.

    Parameters
    ----------
    ccs : array-like, 1-D
        Observed correlation per voxel.
    rccs : array-like
        Null-distribution correlations of any length. Lists, such as the output
        of ``gen_nulldistrib_block``, are accepted.

    Returns
    -------
    significant_voxels : ndarray of bool
        Voxels rejected at alpha=0.05.
    pvalue_corrected : ndarray
        FDR-corrected p-values.

    Raises
    ------
    ValueError
        If ``rccs`` is empty.
    """
    ccs = np.asarray(ccs)
    null = np.asarray(rccs, dtype=float).ravel()
    if null.size == 0:
        raise ValueError("rccs (null distribution) must not be empty.")

    # Empirical one-sided p-value: the fraction of null values strictly greater
    # than the observed cc. The denominator is the null size (not the voxel
    # count), so p stays in [0, 1] when len(rccs) != len(ccs). Sorting once and
    # using searchsorted replaces the O(n_voxels * n_null) scan. NaNs in the null
    # never count as "greater", matching an elementwise `>`.
    null_sorted = np.sort(null[~np.isnan(null)])
    n_greater = null_sorted.size - np.searchsorted(null_sorted, ccs, side="right")
    px = n_greater / null.size

    significant_voxels, pvalue_corrected = fdrcorrection(px, alpha=0.05, method="indep", is_sorted=False)
    if np.any(significant_voxels):
        print(f"Minimum R of siginificant voxels = {np.min(ccs[significant_voxels])}")
        print(f"Maximum R of siginificant voxels = {np.max(ccs[significant_voxels])}")
        print(
            f"Number of voxels with significant positive correlation: {len(np.where(significant_voxels)[0])}"
        )
    else:
        print("No voxels are significant")

    return significant_voxels, pvalue_corrected



def make_himalaya_pipeline(n_samples,
                           n_features,
                           cv,
                           alpha,
                           score_func):
    """Build a (centering -> cross-validated ridge) sklearn pipeline with himalaya.

    When ``n_samples >= n_features``, the primal ``RidgeCV`` solver is used;
    otherwise ``KernelRidgeCV`` is used. Each solver gets its own batch sizes.
    Features are mean-centered but not scaled before the ridge step.
    """
    if n_samples >= n_features:
        print("Solving Ridge regression...")
        estimator_cls = RidgeCV
        batch_params = {"n_targets_batch": 3000,
                        "n_alphas_batch": 20,
                        "n_targets_batch_refit": 200}
    else:
        print("Solving Kernel Ridge regression...")
        estimator_cls = KernelRidgeCV
        batch_params = {"n_targets_batch": 10000,
                        "n_alphas_batch": 200,
                        "n_targets_batch_refit": 2000}

    ridge = estimator_cls(
        alphas=alpha, cv=cv, solver_params={"score_func": score_func, **batch_params}
    )
    centering = make_pipeline(StandardScaler(with_mean=True, with_std=False))
    return make_pipeline(centering, ridge)


def collect_fmri_byroi_for_nsd(subject,
                               trainvalid,
                               atlasname,
                               norm=False):
    """Load trial-averaged NSD betas for one subject and split.

    ``trainvalid == "TRAIN"`` selects the "tr" files; any other value selects "te".
    For ``atlasname == "cortex"``, one precomputed cortex array is loaded (the
    "_norm" variant when ``norm`` is truthy). Otherwise, one array per non-zero
    ROI of the given atlas is loaded and they are concatenated along axis 1.
    ``norm`` is ignored in that case.
    """
    savedir = f'./data/nsd/fmri/{subject}/'
    split_tag = "tr" if trainvalid == "TRAIN" else "te"

    if atlasname == "cortex":
        norm_tag = "_norm" if norm else ""
        betas_all = np.load(f'{savedir}/{subject}_cortex_betas_ave_{split_tag}{norm_tag}.npy')
        print(betas_all.shape)
        return betas_all

    nsda = NSDAccess('./data/NSD')
    atlas = nsda.read_atlas_results(subject=subject, atlas=atlasname, data_format='func1pt8mm')
    per_roi_betas = []
    for roi, label in atlas[1].items():
        print(roi, label)
        # Label 0 is skipped (not loaded).
        if label == 0:
            print('SKIP')
            continue
        roi_betas = np.load(f'{savedir}/{subject}_{roi}_betas_ave_{split_tag}.npy')
        print(roi_betas.shape)
        per_roi_betas.append(roi_betas)

    betas_all = np.hstack(per_roi_betas)
    print(betas_all.shape)
    return betas_all

def _save_projector(projector, path_stem):
    """Save a fitted projector as ``{path_stem}.npy``, falling back to ``{path_stem}.pkl``."""
    try:
        np.save(f"{path_stem}.npy", projector)
    # When the size of the projector is too large, save it as a pickle file.
    except Exception:
        with open(f"{path_stem}.pkl", 'wb') as f:
            pickle.dump(projector, f, protocol=pickle.HIGHEST_PROTOCOL)


def _load_trnval(trn_path, val_path, label):
    """Load a cached train/test feature pair, print its shapes, and wrap it in TrnVal."""
    feats_tr = np.load(trn_path)
    feats_te = np.load(val_path)
    print(f"{label} features' shape: {feats_tr.shape}, {feats_te.shape}")
    return TrnVal(trn=feats_tr, val=feats_te)


def _reduce_with_pca(feats, n_pcs, projector_stem, trn_path, val_path):
    """Fit PCA on feats.trn, cache the projector and reduced features, and return them."""
    feats, projector = fit_pca(feats, n_pcs, return_projector=True)
    _save_projector(projector, projector_stem)
    print(f"Reduced features' shape: {feats.trn.shape}, {feats.val.shape}")
    np.save(trn_path, feats.trn)
    np.save(val_path, feats.val)
    return feats


def _build_nsd_image_features(subject, featdir, use_stim):
    """Assemble per-trial image features and split them into NSD train / shared-test sets."""
    nsd_expdesign = scipy.io.loadmat('./data/NSD/nsddata/experiments/nsd/nsd_expdesign.mat')
    # Most NSD indices are 1-based, hence the -1. A set gives O(1) membership per stimulus.
    shared_ids = set((nsd_expdesign['sharedix'] - 1).ravel().tolist())

    if use_stim == 'ave':
        stims = np.load(f'./data/nsd/fmri/{subject}/{subject}_stims_ave.npy')
    else:  # Each
        stims = np.load(f'./data/nsd/fmri/{subject}/{subject}_stims.npy')

    feats = []
    # 1 marks a training stimulus, 0 one of the shared (test) stimuli.
    tr_idx = np.zeros(len(stims))
    for idx, s in enumerate(tqdm(stims)):
        tr_idx[idx] = 0 if s in shared_ids else 1
        feats.append(np.load(f'{featdir}/{s:06}.npy'))
    feats = np.stack(feats)

    feats_tr = feats[tr_idx == 1, :]
    feats_te = feats[tr_idx == 0, :]
    print(f"Original features' shape: {feats_tr.shape}, {feats_te.shape}")

    np.save(f'./data/nsd/fmri/{subject}/{subject}_stims_tridx.npy', tr_idx)
    np.save(f'{featdir}/{subject}_{use_stim}_tr.npy', feats_tr)
    np.save(f'{featdir}/{subject}_{use_stim}_te.npy', feats_te)
    return TrnVal(trn=feats_tr, val=feats_te)


def collect_stim_for_nsd(subject, modality, featdir, use_stim='ave', reduce_dim="default"):
    """Load (and cache) NSD stimulus features, optionally PCA-reduced.

    Parameters
    ----------
    subject : str
        NSD subject name.
    modality : str
        "image" (per-stimulus feature files in ``featdir``) or "semantic"
        (precomputed ``{subject}_tr/te.npy`` in ``featdir``). Any other value returns None.
    featdir : str
        Feature directory; caches and projectors are written here.
    use_stim : str
        "ave" or another value (per-trial stimuli); used by the image modality.
    reduce_dim : "default" or sequence
        "default" (or ("default", ...)) for raw features, ("pca", n_pcs) for PCA.
        "srp" raises NotImplementedError. The caller's sequence is not modified.

    Returns
    -------
    TrnVal of (train, test) feature arrays, or None for an unknown modality.
    """
    method = reduce_dim if isinstance(reduce_dim, str) else reduce_dim[0]
    if method == "srp":
        # Previously constructed but never raised, so raw features were saved as "PCs".
        raise NotImplementedError("SRP is not implemented yet.")
    n_pcs = int(reduce_dim[1]) if method == "pca" else None

    if modality == "image":
        raw_tr = f'{featdir}/{subject}_{use_stim}_tr.npy'
        raw_te = f'{featdir}/{subject}_{use_stim}_te.npy'
        if method == "default" and os.path.exists(raw_tr):
            print(f'{featdir}/{subject}_{use_stim}.npy already exists.')
            print('Loading this saved file...')
            return _load_trnval(raw_tr, raw_te, "Original")
        if method == "pca":
            pca_tr = f'{featdir}/{subject}_{use_stim}_tr_{n_pcs}PCs.npy'
            pca_te = f'{featdir}/{subject}_{use_stim}_te_{n_pcs}PCs.npy'
            if os.path.exists(pca_tr):
                print(f'{featdir}/{subject}_{use_stim}_{n_pcs}PCs.npy already exists.')
                print('Loading this saved file...')
                return _load_trnval(pca_tr, pca_te, "Reduced")

        feats = _build_nsd_image_features(subject, featdir, use_stim)
        if method != "pca":
            return feats
        return _reduce_with_pca(
            feats, n_pcs, f"{featdir}/projector_{subject}_{use_stim}_{n_pcs}PCs", pca_tr, pca_te
        )

    elif modality == "semantic":
        feat_tr_path = f"{featdir}/{subject}_tr.npy"
        feat_te_path = f"{featdir}/{subject}_te.npy"

        if method == "default":
            return _load_trnval(feat_tr_path, feat_te_path, "Original")
        if method == "pca":
            pca_tr = f'{featdir}/{subject}_tr_{n_pcs}PCs.npy'
            pca_te = f'{featdir}/{subject}_te_{n_pcs}PCs.npy'
            if os.path.exists(pca_tr):
                print(f'{featdir}/{subject}_{n_pcs}PCs.npy already exists.')
                print('Loading this saved file...')
                return _load_trnval(pca_tr, pca_te, "Reduced")

        feats = TrnVal(trn=np.load(feat_tr_path), val=np.load(feat_te_path))
        if method != "pca":
            return feats
        return _reduce_with_pca(
            feats, n_pcs, f"{featdir}/projector_{subject}_ave_{n_pcs}PCs", pca_tr, pca_te
        )

    # Unknown modality: nothing to load.
    return None
    
def create_volume_index_and_weight_map(subject_name,
                                       file_type,
                                       threshold,
                                       model_score_dir,
                                       target_best_cv_layer,
                                       filename,
                                       nsda,
                                       atlasnames):
    """
    Based on ``file_type`` and the atlas selection, build:
    1. the indices of the selected voxels (volume_index),
    2. the map from flattened-volume index to weight-matrix (layer coefficient)
       column (weight_index_map),
    3. the selected voxel indices used when file_type != "full" (target_top_voxels).

    NOTE(review): ``cortex.db.get_mask`` (pycortex) is called below, but ``cortex``
    is not imported in this file — confirm it is provided before calling.

    Parameters
    ----------
    subject_name : str
        Subject name (e.g. "subj01").
    file_type : str
        "full", "cc", "pvalues_corrected", etc.
    threshold : float
        Selection threshold: voxels with cc > threshold ("cc") or
        pvalues_corrected < threshold ("pvalues_corrected").
    model_score_dir : str
        Directory where model scores are saved.
    target_best_cv_layer : str
        Name of the selected (best) layer.
    filename : str
        Score filename stem for the file that is loaded.
    nsda : NSDAccess
        NSD helper (e.g. NSDAccess('./data/NSD')).
    atlasnames : list or str
        Atlas names to use (e.g. ["HCP_MMP1", "prf-visualrois"]) or "cortex".

    Returns
    -------
    volume_index : ndarray of int
        Selected voxel indices in the flattened cortex-mask volume.
    weight_index_map : ndarray of int
        Same length as the flattened mask; -1 outside the mask, otherwise the
        voxel's position among the mask's voxels (i.e. its weight column).
    target_top_voxels : ndarray or range
        Voxels selected when file_type != "full"; for "full" this is
        range(num_voxel) over the atlas ROIs.
    """
    if file_type != "full":
        # Load the score file and threshold it to select voxels.
        target_values = np.load(f"{model_score_dir}/{target_best_cv_layer}/{file_type}_{filename}.npy")
        if file_type == "cc":
            target_top_voxels = np.where(target_values > threshold)[0]
        elif file_type == "pvalues_corrected":
            target_top_voxels = np.where(target_values < threshold)[0]
        else:
            # Other file_type patterns can be added here. Integer dtype keeps the
            # empty result valid as an index array (float [] raises IndexError).
            target_top_voxels = np.array([], dtype=int)
        print(f"Number of selected voxels: {len(target_top_voxels)}")
    else:
        # "full": count every voxel belonging to a non-zero ROI label.
        # NOTE(review): atlasnames is passed as-is; with several atlases only the
        # single result returned by read_atlas_results is used here — confirm intent.
        atlas = nsda.read_atlas_results(subject=subject_name, atlas=atlasnames, data_format='func1pt8mm')
        atlas_data = atlas[0]  # voxel label values
        atlas_dict = atlas[1]  # ROI name -> label value
        num_voxel = 0
        for roi, val in atlas_dict.items():
            print(roi, val)
            if val == 0:
                print('SKIP')
                continue
            num_voxel += int(np.count_nonzero(atlas_data == val))
        print(f"Number of {atlasnames}'s voxels: {num_voxel}")
        target_top_voxels = range(num_voxel)

    # Build volume indices from the cortex mask.
    ctx_mask = cortex.db.get_mask(subject_name, "full")
    ctx_mask_flat = ctx_mask.flatten()
    ctx_mask_index = np.where(ctx_mask_flat)[0]

    # weight_index_map: flattened-volume index -> weight (coefficient) index.
    weight_index_map = np.full(len(ctx_mask_flat), -1, dtype=int)
    weight_index_map[ctx_mask_index] = np.arange(len(ctx_mask_index))

    volume_index = None
    if atlasnames != "cortex" and isinstance(atlasnames, list):
        # Union over atlases/ROIs of (ROI AND cortex mask).
        for atlasname in atlasnames:
            atlas = nsda.read_atlas_results(subject=subject_name, atlas=atlasname, data_format='func1pt8mm')
            atlas_data = atlas[0].transpose([2, 1, 0])
            atlas_dict = atlas[1]

            for roi, val in atlas_dict.items():
                # Special case: for HCP_MMP1 only ROIs whose name contains "PCV" are used.
                if atlasname == "HCP_MMP1" and "PCV" not in roi:
                    continue
                print(roi, val)
                if val == 0:
                    print('SKIP')
                    continue
                roi_in_mask = np.logical_and((atlas_data == val).flatten(), ctx_mask_flat)
                roi_mask_index = np.where(roi_in_mask)[0]
                if volume_index is None:
                    volume_index = roi_mask_index
                else:
                    volume_index = np.union1d(volume_index, roi_mask_index)
        if volume_index is None:
            # No ROI contributed voxels; return an empty index instead of None.
            volume_index = np.array([], dtype=int)
    else:
        # "cortex" (or non-list) atlas: select target_top_voxels within the cortex mask.
        volume_index = ctx_mask_index[target_top_voxels]

    print(f"Number of selected voxels: {len(volume_index)}")

    return volume_index, weight_index_map, target_top_voxels


def load_keyword_pipeline(keywords_model: str):
    """Return a local text-generation pipeline for keyword extraction, or None.

    Only "Llama-3.1-70B-Instruct" builds a local pipeline (bfloat16, 8-bit
    quantization, ``device_map="auto"``). Every other model name (e.g.
    GPT-based APIs) returns None.

    NOTE(review): ``pipeline`` is not imported in this file (presumably
    ``transformers.pipeline``) — confirm before using the Llama branch.
    """
    if keywords_model != "Llama-3.1-70B-Instruct":
        # For GPT-based API or default: no local pipeline.
        return None

    return pipeline(
        "text-generation",
        model="meta-llama/Llama-3.1-70B-Instruct",
        model_kwargs={
            "torch_dtype": torch.bfloat16,
            "quantization_config": {"load_in_8bit": True},
        },
        device_map="auto",
    )

def get_layer_info_and_weight(
    subject_name: str,
    args,
    nsda: NSDAccess,
    score_root_path: str
):
    """
    Explanation:
    Determines the best (or given) layer for a subject, loads the
    regression weight for that layer, and returns relevant information
    including volume index, weight map, etc.

    ``args`` must provide: modality, modality_hparam, model_name,
    voxel_selection (file_type, threshold), reduce_dims, layer_selection
    ("best" or a layer name) and atlasname.

    Returns
    -------
    (layer name, layer number (int, or the string suffix if not numeric),
     coefficient array from ``coef_{filename}.npy``, volume_index, weight_index_map)
    """
    modality = args.modality
    modality_hparam = args.modality_hparam
    model_name = args.model_name
    file_type = args.voxel_selection[0]
    threshold = float(args.voxel_selection[1])

    filename = make_filename(args.reduce_dims[0:2])

    # Determine directory for the model scores
    model_score_dir = f"{score_root_path}/{subject_name}/scores/{modality}/{modality_hparam}/{model_name}"

    # Load the best layer
    if args.layer_selection == "best":
        target_best_cv_layer, _, _ = search_best_layer(model_score_dir, filename, select_topN="all")
    else:
        target_best_cv_layer = args.layer_selection

    layer_suffix = target_best_cv_layer.replace("layer", "")
    try:
        target_best_cv_layer_num = int(layer_suffix)
    except ValueError:
        # Non-numeric suffix: keep it as a string.
        target_best_cv_layer_num = layer_suffix

    layer_path = f"{model_score_dir}/{target_best_cv_layer}"
    layer_weight = np.load(f"{layer_path}/coef_{filename}.npy")

    volume_index, weight_index_map, _ = create_volume_index_and_weight_map(
        subject_name=subject_name,
        file_type=file_type,
        threshold=threshold,
        model_score_dir=model_score_dir,
        target_best_cv_layer=target_best_cv_layer,
        filename=filename,
        nsda=nsda,
        atlasnames=args.atlasname
    )

    return target_best_cv_layer, target_best_cv_layer_num, layer_weight, volume_index, weight_index_map

def load_reducer_projector(args, subject_name: str, target_best_cv_layer: str):
    """
    Explanation:
    Loads a dimensionality reduction projector if requested (reduce_dims != 'default').
    Returns None if no dimensionality reduction is used.

    The projector is read from
    ``./data/stim_features/nsd/{modality}/{modality_hparam}/{model_name}/{layer}/projector_{subject}_ave_{filename}``
    with the ``.npy`` file preferred and the ``.pkl`` fallback used otherwise.
    Both formats rely on pickle, so only load locally generated, trusted files.
    """
    if args.reduce_dims[0] == "default":
        return None

    stim_root_path = "./data/stim_features/nsd"
    filename = make_filename(args.reduce_dims[0:2])
    proj_stem = (
        f"{stim_root_path}/{args.modality}/{args.modality_hparam}/{args.model_name}/"
        f"{target_best_cv_layer}/projector_{subject_name}_ave_{filename}"
    )

    npy_path = f"{proj_stem}.npy"
    if os.path.exists(npy_path):
        # np.save stored the object as a 0-d object array; .item() unwraps it.
        return np.load(npy_path, allow_pickle=True).item()

    # Projectors too large for np.save were written with pickle.dump.
    with open(f"{proj_stem}.pkl", "rb") as f:
        return pickle.load(f)

def create_and_check_temp_file(voxel_index, resp_save_path, args):
    """
    Explanation:
    Creates a temporary file indicating the voxel is being processed.
    If such a file already exists, it implies parallel processing or
    that the voxel is already processed, so we skip further work.
    Returns a path to the temporary file or None if we should skip.

    The file is created with mode "x" (exclusive creation). The check and the
    creation are a single operation, so two parallel workers cannot both claim
    the same voxel; a separate exists()/open() pair left a race window.
    """
    temp_file_path = (
        f"{resp_save_path}/temp_vcaption_{args.caption_model}"
        f"_kmodel_{args.keywords_model}_{args.key_num}keys_{args.candidate_num}cands_cmodel_{args.correct_model}.txt"
    )

    try:
        with open(temp_file_path, "x"):
            pass
    except FileExistsError:
        print(f"Simulation for voxel {voxel_index} is being processed in parallel.")
        return None  # Signal to skip

    return temp_file_path


def check_existing_keywords_file(voxel_index, keywords_file_path, temp_file_path, args):
    """
    Explanation:
    Checks if a keywords/caption file already exists for this voxel.
    If it exists and contains a valid text or best_text, we skip processing.

    The file counts as "done" when:
    - the caption model contains "MeaCap" and "text" is non-empty, or
    - the caption model contains "gpt-4o-mini" and "best_text" is non-empty.
    When done, the voxel's temp file is removed and True is returned. A
    missing, unreadable or malformed file returns False, so the voxel gets
    (re)processed.
    """
    if not os.path.exists(keywords_file_path):
        return False  # No file to check, so continue

    try:
        with open(keywords_file_path, "r") as f:
            keys_and_text = json.load(f)
        text = keys_and_text.get("text", "")
        best_text = keys_and_text.get("best_text", "")
    except (OSError, ValueError, AttributeError) as e:
        # OSError: unreadable; ValueError: invalid JSON; AttributeError: JSON is not an object.
        print(f"Could not read {keywords_file_path}: {e}")
        return False

    if "MeaCap" in args.caption_model:
        already_done = bool(text)
    elif "gpt-4o-mini" in args.caption_model:
        already_done = bool(best_text)
    else:
        already_done = False

    if not already_done:
        return False

    print(f"Already processed: {keywords_file_path}")
    try:
        os.remove(temp_file_path)
    except OSError as e:
        # The result already exists; a leftover or missing temp file does not change that.
        print(f"Could not remove temp file {temp_file_path}: {e}")
    return True


def find_existing_keywords_file_if_needed(resp_save_path, keywords_file_path, args):
    """
    Explanation:
    If the main keywords file does not exist, tries to find a similar file
    containing the needed keywords by searching certain patterns in the directory.
    Returns any loaded `keys_and_text` (or empty dict if none found).

    The search only runs for keyword models containing "gpt-4" or "Llama". A
    candidate file's name must contain the keywords model, "<candidate_num>cands"
    and "<key_num>keys" as substrings. Candidates are tried in sorted name order,
    so the choice is deterministic.
    NOTE(review): substring matching means "5cands" also matches "15cands".
    """
    if os.path.exists(keywords_file_path):
        return {}  # We don't need to search if the file exists

    if "gpt-4" not in args.keywords_model and "Llama" not in args.keywords_model:
        return {}

    search_strings = [
        args.keywords_model,
        f"{args.candidate_num}cands",
        f"{args.key_num}keys"
    ]
    # Sorted: os.listdir order is arbitrary, which made the chosen file nondeterministic.
    file_names = sorted(
        f for f in os.listdir(resp_save_path)
        if os.path.isfile(os.path.join(resp_save_path, f))
    )
    for file_name in file_names:
        if not all(s in file_name for s in search_strings):
            continue
        file_path = os.path.join(resp_save_path, file_name)
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                keys_and_text = json.load(file)
        except (OSError, ValueError) as e:
            print(f"Error reading file '{file_path}': {e}")
            continue
        print(f"File '{file_path}' includes the keywords.")
        return keys_and_text

    print(f"File not found: {keywords_file_path}")
    return {}


def load_top_captions(resp_save_path, args):
    """Read the dataset captions of the top-ranked stimuli, ordered by rank.

    Stimulus files in ``{resp_save_path}/stim_top100`` are named like
    ``top<k>_segment_<s>_<frame>.jpg``. Each is mapped to the caption file
    ``{args.dataset_path}/segment_<s>/caption_<frame>_{args.dataset_captioner}.txt``.
    The captions are returned in increasing ``k``.
    """
    rank_re = re.compile(r"top(\d+)_.*")
    stim_names = [
        name for name in os.listdir(f"{resp_save_path}/stim_top100") if rank_re.match(name)
    ]
    stim_names.sort(key=lambda name: int(rank_re.match(name).group(1)))

    captions_list = []
    for stim_name in stim_names:
        # "top<k>_segment_<s>_<frame>.jpg" -> "segment_<s>/<frame>"
        seg_and_frame = re.sub(r'top\d+_(segment_\d+)_', r'\1/', stim_name).replace(".jpg", "")
        parts = seg_and_frame.split("/")
        cap_rel = parts[0] + "/" + "caption_" + parts[1] + "_" + args.dataset_captioner + ".txt"
        with open(f"{args.dataset_path}/{cap_rel}", "r") as f:
            captions_list.append(f.read())

    return captions_list


def save_top_captions_to_json(captions_list, resp_save_path, args):
    """Write ``captions_list`` as 4-space-indented JSON to
    ``<resp_save_path>/top100_captions_<args.dataset_captioner>.json``."""
    out_name = f"top100_captions_{args.dataset_captioner}.json"
    out_path = os.path.join(resp_save_path, out_name)
    with open(out_path, "w") as fh:
        json.dump(captions_list, fh, indent=4)
